Consolidate BTYD Models
1 Load Transactional Datasets
We first want to load the real-world transactional dataset.
1.1 Load Pre-processed Transactional Data
customer_cohortdata_tbl <- read_rds("data/customer_cohort_tbl.rds")
customer_cohortdata_tbl %>% glimpse()
## Rows: 5,852
## Columns: 5
## $ customer_id <chr> "12346", "12347", "12348", "12349", "12350", "12351", …
## $ cohort_qtr <yearqtr> 2010 Q1, 2010 Q4, 2010 Q3, 2010 Q2, 2011 Q1, 2010 …
## $ cohort_ym <chr> "2010 03", "2010 10", "2010 09", "2010 04", "2011 02",…
## $ first_tnx_date <date> 2010-03-02, 2010-10-31, 2010-09-27, 2010-04-29, 2011-…
## $ total_tnx_count <int> 3, 8, 5, 3, 1, 1, 9, 2, 1, 2, 6, 2, 5, 10, 6, 4, 10, 2…
We also load the raw transaction data, as we want to transform it into the form used in this workbook.
retail_transaction_data_tbl <- read_rds("data/retail_data_cleaned_tbl.rds")
retail_transaction_data_tbl %>% glimpse()
## Rows: 1,021,424
## Columns: 23
## $ row_id <chr> "ROW0000001", "ROW0000002", "ROW0000003", "ROW000000…
## $ excel_sheet <chr> "Year 2009-2010", "Year 2009-2010", "Year 2009-2010"…
## $ invoice_id <chr> "489434", "489434", "489434", "489434", "489434", "4…
## $ stock_code <chr> "85048", "79323P", "79323W", "22041", "21232", "2206…
## $ description <chr> "15CM CHRISTMAS GLASS BALL 20 LIGHTS", "PINK CHERRY …
## $ quantity <dbl> 12, 12, 12, 48, 24, 24, 24, 10, 12, 12, 24, 12, 10, …
## $ invoice_date <date> 2009-12-01, 2009-12-01, 2009-12-01, 2009-12-01, 200…
## $ price <dbl> 6.95, 6.75, 6.75, 2.10, 1.25, 1.65, 1.25, 5.95, 2.55…
## $ customer_id <chr> "13085", "13085", "13085", "13085", "13085", "13085"…
## $ country <chr> "United Kingdom", "United Kingdom", "United Kingdom"…
## $ stock_code_upr <chr> "85048", "79323P", "79323W", "22041", "21232", "2206…
## $ cancellation <lgl> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FAL…
## $ invoice_dttm <dttm> 2009-12-01 07:45:00, 2009-12-01 07:45:00, 2009-12-0…
## $ invoice_month <chr> "December", "December", "December", "December", "Dec…
## $ invoice_dow <chr> "Tuesday", "Tuesday", "Tuesday", "Tuesday", "Tuesday…
## $ invoice_dom <chr> "01", "01", "01", "01", "01", "01", "01", "01", "01"…
## $ invoice_hour <chr> "07", "07", "07", "07", "07", "07", "07", "07", "07"…
## $ invoice_minute <chr> "45", "45", "45", "45", "45", "45", "45", "45", "45"…
## $ invoice_woy <chr> "49", "49", "49", "49", "49", "49", "49", "49", "49"…
## $ invoice_ym <chr> "200912", "200912", "200912", "200912", "200912", "2…
## $ stock_value <dbl> 83.40, 81.00, 81.00, 100.80, 30.00, 39.60, 30.00, 59…
## $ invoice_monthprop <dbl> 0.04347826, 0.04347826, 0.04347826, 0.04347826, 0.04…
## $ exclude <lgl> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FAL…
We need to aggregate this data into a form that matches our synthetic data, so
we sum the line items of each invoice into a single transaction per invoice_id.
customer_transactions_tbl <- retail_transaction_data_tbl %>%
drop_na(customer_id) %>%
filter(!exclude) %>%  # keep only rows not flagged for exclusion (`=` was a typo for a comparison)
group_by(tnx_timestamp = invoice_dttm, customer_id, invoice_id) %>%
summarise(
.groups = "drop",
total_spend = sum(stock_value)
) %>%
filter(total_spend > 0) %>%
arrange(tnx_timestamp, customer_id)
customer_transactions_tbl %>% glimpse()
## Rows: 37,031
## Columns: 4
## $ tnx_timestamp <dttm> 2009-12-01 07:45:00, 2009-12-01 07:45:59, 2009-12-01 09…
## $ customer_id <chr> "13085", "13085", "13078", "15362", "18102", "12682", "1…
## $ invoice_id <chr> "489434", "489435", "489436", "489437", "489438", "48943…
## $ total_spend <dbl> 505.30, 145.80, 630.33, 310.75, 2286.24, 426.30, 50.40, …
We reproduce the visualisation of transaction times that we used in previous workbooks.
plot_tbl <- customer_transactions_tbl %>%
group_nest(customer_id, .key = "cust_data") %>%
filter(map_int(cust_data, nrow) > 3) %>%
slice_sample(n = 30) %>%
unnest(cust_data)
ggplot(plot_tbl, aes(x = tnx_timestamp, y = customer_id)) +
geom_line() +
geom_point() +
labs(
x = "Date",
y = "Customer ID",
title = "Visualisation of Customer Transaction Times"
) +
theme(axis.text.y = element_text(size = 10))
2 Fit the Fixed Prior P/NBD Model
stan_modeldir <- "stan_models"
stan_codedir <- "stan_code"
We first need to construct our fitted dataset from this external data.
For the cut-off point, we fit the model on all transactions before January 1, 2011, and hold out the later transactions for validation.
btyd_fitdata_tbl <- customer_transactions_tbl %>%
calculate_transaction_cbs_data(last_date = as.POSIXct("2011-01-01"))
btyd_fitdata_tbl %>% glimpse()
## Rows: 4,342
## Columns: 6
## $ customer_id <chr> "12346", "12347", "12348", "12349", "12351", "12352", "…
## $ first_tnx_date <dttm> 2009-12-14 08:34:00, 2010-10-31 14:19:59, 2010-09-27 1…
## $ last_tnx_date <dttm> 2010-06-28 13:53:00, 2010-12-07 14:56:59, 2010-12-16 1…
## $ x <dbl> 10, 1, 1, 2, 0, 1, 0, 0, 2, 1, 2, 5, 4, 2, 0, 0, 0, 2, …
## $ t_x <dbl> 28.03164683, 5.28938492, 11.45337302, 25.97053571, 0.00…
## $ T_cal <dbl> 54.663294, 8.771825, 13.625099, 35.206349, 4.622718, 7.…
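The helper `calculate_transaction_cbs_data()` is defined elsewhere in the workshop code, so its exact logic is not visible here. As a rough sketch only, assuming that `x` counts repeat transactions, that times are measured in weeks, and that only transactions strictly before `last_date` are used, the summary columns above could be derived along these lines. The function name `sketch_cbs_data()` and the toy data are hypothetical.

```r
library(dplyr)

# Hypothetical stand-in for calculate_transaction_cbs_data(); the real helper
# is not shown in this workbook, so these column definitions are assumptions.
sketch_cbs_data <- function(tnx_tbl, last_date) {
  tnx_tbl %>%
    filter(tnx_timestamp < last_date) %>%
    group_by(customer_id) %>%
    summarise(
      .groups        = "drop",
      first_tnx_date = min(tnx_timestamp),
      last_tnx_date  = max(tnx_timestamp),
      x              = n() - 1              # repeat transactions only
    ) %>%
    mutate(
      # time from first to most recent transaction, in weeks
      t_x   = as.numeric(difftime(last_tnx_date, first_tnx_date, units = "weeks")),
      # total observation time from first transaction to the cut-off, in weeks
      T_cal = as.numeric(difftime(last_date, first_tnx_date, units = "weeks"))
    )
}

# Toy data: customer A buys three times, customer B once
toy_tbl <- tibble(
  customer_id   = c("A", "A", "A", "B"),
  tnx_timestamp = as.POSIXct(
    c("2010-12-04", "2010-12-11", "2010-12-18", "2010-12-25"),
    tz = "UTC"
  )
)

toy_cbs_tbl <- sketch_cbs_data(toy_tbl, as.POSIXct("2011-01-01", tz = "UTC"))
```

On the toy data, customer A has two repeat purchases spread over two weeks and has been observed for four weeks, while customer B has no repeat purchases and one week of observation.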
We also want to construct some summary statistics for the transactions on or after the cut-off date.
btyd_obs_stats_tbl <- customer_transactions_tbl %>%
filter(
tnx_timestamp >= as.POSIXct("2011-01-01")
) %>%
group_by(customer_id) %>%
summarise(
.groups = "drop",
tnx_count = n(),
first_tnx = min(tnx_timestamp),
last_tnx = max(tnx_timestamp)
)
btyd_obs_stats_tbl %>% glimpse()
## Rows: 4,219
## Columns: 4
## $ customer_id <chr> "12346", "12347", "12348", "12349", "12350", "12352", "123…
## $ tnx_count <int> 1, 6, 3, 1, 1, 8, 1, 1, 1, 3, 1, 2, 4, 3, 1, 10, 2, 4, 2, …
## $ first_tnx <dttm> 2011-01-18 10:00:59, 2011-01-26 14:29:59, 2011-01-25 10:4…
## $ last_tnx <dttm> 2011-01-18 10:00:59, 2011-12-07 15:51:59, 2011-09-25 13:1…
We now compile this model using CmdStanR.
pnbd_consol_fixed_stanmodel <- cmdstan_model(
"stan_code/pnbd_fixed.stan",
include_paths = stan_codedir,
pedantic = TRUE,
dir = stan_modeldir
)
2.1 Fit the Model
We then use this compiled model with our data to sample from the posterior distribution.
stan_modelname <- "pnbd_consol_fixed"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- btyd_fitdata_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn = 0.25,
lambda_cv = 1.00,
mu_mn = 0.05,
mu_cv = 1.00,
)
pnbd_consol_fixed_stanfit <- pnbd_consol_fixed_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4201,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix,
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 110.1 seconds.
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 111.2 seconds.
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 finished in 112.6 seconds.
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 113.1 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 111.8 seconds.
## Total execution time: 113.4 seconds.
pnbd_consol_fixed_stanfit$summary()
## # A tibble: 13,027 × 10
## variable mean median sd mad q5 q95 rhat ess_b…¹
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -7.53e+4 -7.53e+4 73.3 76.1 -7.54e+4 -7.52e+4 1.00 631.
## 2 lambda[1] 3.12e-1 3.02e-1 0.101 0.0970 1.71e-1 5.00e-1 1.00 1817.
## 3 lambda[2] 1.63e-1 1.33e-1 0.118 0.0986 2.73e-2 3.78e-1 1.00 2479.
## 4 lambda[3] 1.14e-1 9.90e-2 0.0777 0.0701 2.10e-2 2.74e-1 1.00 1652.
## 5 lambda[4] 8.04e-2 6.97e-2 0.0489 0.0407 2.24e-2 1.76e-1 1.00 1679.
## 6 lambda[5] 1.30e-1 8.62e-2 0.131 0.0880 9.28e-3 4.03e-1 0.999 1847.
## 7 lambda[6] 2.02e-1 1.65e-1 0.152 0.124 3.33e-2 4.85e-1 1.00 1705.
## 8 lambda[7] 1.13e-1 7.09e-2 0.129 0.0772 4.89e-3 3.56e-1 1.00 1560.
## 9 lambda[8] 1.17e-1 6.14e-2 0.154 0.0731 3.04e-3 4.39e-1 1.01 1821.
## 10 lambda[9] 2.07e-1 1.85e-1 0.118 0.105 6.06e-2 4.22e-1 1.00 1549.
## # … with 13,017 more rows, 1 more variable: ess_tail <dbl>, and abbreviated
## # variable name ¹ess_bulk
We have some basic HMC-based validity statistics we can check.
pnbd_consol_fixed_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_pnbd_consol_fixed-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_consol_fixed-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_consol_fixed-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_consol_fixed-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
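The R-hat check reported above compares the variation between chains with the variation within each chain. A minimal base R sketch of the classic (non-split) Gelman-Rubin statistic is given below, purely to illustrate the idea. The diagnostic reported by CmdStan is a split variant, so the numbers will not match exactly, and the function name `classic_rhat()` and the simulated chains are assumptions made for this example.

```r
# Classic Gelman-Rubin R-hat for a matrix of draws:
# rows are iterations, columns are chains.
classic_rhat <- function(draws_mat) {
  n_iter      <- nrow(draws_mat)
  chain_means <- colMeans(draws_mat)
  chain_vars  <- apply(draws_mat, 2, var)

  W <- mean(chain_vars)            # within-chain variance
  B <- n_iter * var(chain_means)   # between-chain variance

  # pooled estimate of the posterior variance
  var_plus <- (n_iter - 1) / n_iter * W + B / n_iter

  sqrt(var_plus / W)
}

set.seed(42)

# Four well-mixed chains drawn from the same distribution
mixed_chains <- matrix(rnorm(4 * 500), ncol = 4)

# The same chains, but with the fourth chain shifted away from the others
stuck_chains <- mixed_chains + rep(c(0, 0, 0, 3), each = 500)

rhat_mixed <- classic_rhat(mixed_chains)
rhat_stuck <- classic_rhat(stuck_chains)
```

Here `rhat_mixed` should sit very close to 1, while `rhat_stuck` should be well above it because one chain explores a different region from the rest.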
2.2 Visual Diagnostics of the Sample Validity
Now that we have a sample from the posterior distribution we need to create a few different visualisations of the diagnostics.
parameter_subset <- c(
"lambda[1]", "lambda[2]", "lambda[3]", "lambda[4]",
"mu[1]", "mu[2]", "mu[3]", "mu[4]"
)
pnbd_consol_fixed_stanfit$draws(inc_warmup = FALSE) %>%
mcmc_trace(pars = parameter_subset) +
expand_limits(y = 0) +
labs(
x = "Iteration",
y = "Value",
title = "Traceplot of Sample of Lambda and Mu Values"
) +
theme(axis.text.x = element_text(size = 10))
A common MCMC diagnostic is \(\hat{R}\), a measure of the ‘similarity’ of the chains: values close to 1 indicate that the chains agree with each other.
pnbd_consol_fixed_stanfit %>%
rhat(pars = c("lambda", "mu")) %>%
mcmc_rhat() +
ggtitle("Plot of Parameter R-hat Values")
Related to this quantity is the concept of effective sample size, \(N_{eff}\), an estimate of the size of the sample from a statistical information point of view.
pnbd_consol_fixed_stanfit %>%
neff_ratio(pars = c("lambda", "mu")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes")
Finally, we also want to look at autocorrelation in the chains for each parameter.
pnbd_consol_fixed_stanfit$draws() %>%
mcmc_acf(pars = parameter_subset) +
ggtitle("Autocorrelation Plot of Sample Values")
2.3 Validate the Fixed Prior Model
run_chunk <- function(sim_file, param_tbl)
run_pnbd_simulations_chunk(
sim_file, param_tbl,
start_dttm = as.POSIXct("2011-01-01"),
end_dttm = as.POSIXct("2011-12-10")
)
pnbd_consol_fixed_valid_lst <- construct_model_validation_data(
btyd_stanfit = pnbd_consol_fixed_stanfit,
btyd_fitdata_tbl = btyd_fitdata_tbl,
btyd_obs_stats_tbl = btyd_obs_stats_tbl,
precompute_dir = "precompute/pnbd_consol_fixed",
precompute_key = "sims_pnbd_consol_fixed",
run_chunk_func = run_chunk
)
pnbd_consol_fixed_valid_lst %>% glimpse(max.level = 1)
## List of 4
## $ btyd_validation_tbl : tibble [8,684,000 × 5] (S3: tbl_df/tbl/data.frame)
## $ btyd_validsims_tbl : tibble [8,684,000 × 5] (S3: tbl_df/tbl/data.frame)
## $ valid_custcount_plot:List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
## $ valid_tnxcount_plot :List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
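The simulation helper `run_pnbd_simulations_chunk()` is not shown in this workbook. As a hedged sketch of the kind of forward simulation it presumably performs, the code below draws holdout transaction counts under P/NBD assumptions: a customer is alive at the cut-off with probability `p_alive`, the remaining lifetime is exponential with rate `mu`, and purchases while alive follow a Poisson process with rate `lambda`. The function name `simulate_pnbd_tnx_count()` and the parameter values are hypothetical; the 49-week window corresponds to the period from 2011-01-01 to 2011-12-10.

```r
# Hypothetical sketch of a P/NBD forward simulation for a holdout window.
# All rates are per week, matching the weekly time units of the fit data.
simulate_pnbd_tnx_count <- function(lambda, mu, p_alive, window_weeks) {
  n_cust <- length(lambda)

  # is each customer still alive at the start of the window?
  is_alive <- rbinom(n_cust, size = 1, prob = p_alive) == 1

  # remaining lifetime of an alive customer (exponential is memoryless)
  remaining_life <- rexp(n_cust, rate = mu)

  # time spent active inside the window; zero for customers already lost
  active_weeks <- ifelse(is_alive, pmin(remaining_life, window_weeks), 0)

  # Poisson number of purchases over the active time
  rpois(n_cust, lambda = lambda * active_weeks)
}

set.seed(4201)

sim_counts <- simulate_pnbd_tnx_count(
  lambda       = c(0.30, 0.10, 0.20),
  mu           = c(0.05, 0.05, 0.05),
  p_alive      = c(0.90, 0.50, 0.00),
  window_weeks = 49
)
```

Repeating such a draw once per posterior sample and per customer gives a distribution of simulated counts that can be compared with the observed holdout statistics.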
Having run the simulations, we now want to check the outputs against the observed data.
pnbd_consol_fixed_valid_lst$valid_custcount_plot %>% print()
pnbd_consol_fixed_valid_lst$valid_tnxcount_plot %>% print()
2.4 Write to Disk
pnbd_consol_fixed_valid_lst %>% write_rds("data/pnbd_consol_fixed_valid_lst.rds")
3 Fit the Hierarchical Lambda-Mean P/NBD Model
## functions {
## #include util_functions.stan
## }
##
## data {
## int<lower=1> n; // number of customers
##
## vector<lower=0>[n] t_x; // time to most recent purchase
## vector<lower=0>[n] T_cal; // total observation time
## vector<lower=0>[n] x; // number of purchases observed
##
## real<lower=0> lambda_cv; // prior cv for lambda
## real<lower=0> mu_mn; // prior mean for mu
## real<lower=0> mu_cv; // prior cv for mu
##
## real lambda_mn_p1; // hyperprior p1 for lambda mean
## real<lower=0> lambda_mn_p2; // hyperprior p2 for lambda mean
## }
##
## transformed data {
## real<lower=0> s = 1 / (mu_cv * mu_cv);
## real<lower=0> beta = 1 / (mu_cv * mu_cv * mu_mn);
## }
##
##
## parameters {
## real<lower=0> lambda_mn;
##
## vector<lower=0>[n] lambda; // purchase rate
## vector<lower=0>[n] mu; // lifetime dropout rate
## }
##
##
## transformed parameters {
## real<lower=0> r;
## real<lower=0> alpha;
##
## r = 1 / (lambda_cv * lambda_cv);
## alpha = 1 / (lambda_cv * lambda_cv * lambda_mn);
## }
##
## model {
## // model the hyper-prior
## lambda_mn ~ lognormal(lambda_mn_p1, lambda_mn_p2);
##
## // setting priors
## lambda ~ gamma(r, alpha);
## mu ~ gamma(s, beta);
##
## target += calculate_pnbd_loglik(n, lambda, mu, x, t_x, T_cal);
## }
##
## generated quantities {
## vector[n] p_alive; // Probability that they are still "alive"
##
## p_alive = 1 ./ (1 + mu ./ (mu + lambda) .* (exp((lambda + mu) .* (T_cal - t_x)) - 1));
## }
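The `generated quantities` block above computes, for each customer, the probability of still being alive at the end of the observation period. The same expression can be written as a small vectorised R function, which is handy for checking how the probability behaves; the function name `pnbd_p_alive()` and the example values are assumptions made for illustration.

```r
# Probability that a customer is still alive at T_cal, given individual-level
# lambda and mu, using the same expression as the Stan generated quantities block.
pnbd_p_alive <- function(lambda, mu, t_x, T_cal) {
  1 / (1 + mu / (mu + lambda) * (exp((lambda + mu) * (T_cal - t_x)) - 1))
}

# A customer whose last purchase coincides with the end of the observation period
p_alive_no_gap <- pnbd_p_alive(lambda = 0.3, mu = 0.05, t_x = 10, T_cal = 10)

# The same customer with progressively longer gaps (in weeks) since that purchase
gap_weeks      <- c(0, 1, 4, 12, 26)
p_alive_by_gap <- pnbd_p_alive(
  lambda = 0.3, mu = 0.05, t_x = 10, T_cal = 10 + gap_weeks
)
```

With no gap since the last purchase the probability is exactly 1, and it falls steadily as the gap `T_cal - t_x` grows.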
We now compile this model using CmdStanR.
pnbd_consol_lambmn_stanmodel <- cmdstan_model(
"stan_code/pnbd_hierlambdamn.stan",
include_paths = stan_codedir,
pedantic = TRUE,
dir = stan_modeldir
)
3.1 Fit the Model
We then use this compiled model with our data to sample from the posterior distribution.
stan_modelname <- "pnbd_consol_lambmn"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- btyd_fitdata_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn_p1 = log(0.25) - 0.5 * (1.0)^2,
lambda_mn_p2 = 1,
lambda_cv = 1.00,
mu_mn = 0.05,
mu_cv = 0.60,
)
pnbd_consol_lambmn_stanfit <- pnbd_consol_lambmn_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4202,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix,
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 133.1 seconds.
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 finished in 137.5 seconds.
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 139.6 seconds.
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 142.5 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 138.2 seconds.
## Total execution time: 142.9 seconds.
pnbd_consol_lambmn_stanfit$summary()
## # A tibble: 13,030 × 10
## variable mean median sd mad q5 q95 rhat ess_b…¹
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -1.00e+5 -1.00e+5 7.11e+1 6.89e+1 -1.00e+5 -9.98e+4 1.00 518.
## 2 lambda_mn 1.19e-1 1.19e-1 2.19e-3 2.09e-3 1.15e-1 1.22e-1 1.00 2096.
## 3 lambda[1] 2.76e-1 2.67e-1 8.62e-2 7.96e-2 1.54e-1 4.37e-1 1.00 2684.
## 4 lambda[2] 1.17e-1 9.73e-2 8.48e-2 7.26e-2 1.83e-2 2.80e-1 1.00 2304.
## 5 lambda[3] 9.18e-2 7.64e-2 6.66e-2 5.58e-2 1.56e-2 2.21e-1 0.999 2997.
## 6 lambda[4] 7.13e-2 6.43e-2 4.09e-2 3.65e-2 1.94e-2 1.47e-1 1.00 2236.
## 7 lambda[5] 8.46e-2 5.64e-2 8.65e-2 5.97e-2 3.96e-3 2.59e-1 1.00 2766.
## 8 lambda[6] 1.39e-1 1.16e-1 9.74e-2 8.56e-2 2.88e-2 3.35e-1 1.00 2830.
## 9 lambda[7] 6.88e-2 4.71e-2 7.15e-2 4.96e-2 3.26e-3 2.11e-1 1.00 1561.
## 10 lambda[8] 6.27e-2 3.69e-2 7.28e-2 3.91e-2 2.78e-3 2.04e-1 1.00 2038.
## # … with 13,020 more rows, 1 more variable: ess_tail <dbl>, and abbreviated
## # variable name ¹ess_bulk
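The value supplied for `lambda_mn_p1` in the data list, `log(0.25) - 0.5 * (1.0)^2`, is chosen so that the lognormal hyperprior on `lambda_mn` has a mean of 0.25: a lognormal with log-scale parameters `meanlog` and `sdlog` has mean `exp(meanlog + sdlog^2 / 2)`. The short check below illustrates this; the variable names are hypothetical.

```r
# Target mean and log-scale standard deviation of the lognormal hyperprior
prior_mean <- 0.25
sdlog      <- 1.0

# The correction term -0.5 * sdlog^2 offsets the skew of the lognormal
meanlog <- log(prior_mean) - 0.5 * sdlog^2

# Mean implied by these parameters
implied_mean <- exp(meanlog + 0.5 * sdlog^2)

# 5%, 50% and 95% quantiles of the hyperprior on lambda_mn
prior_quantiles <- qlnorm(c(0.05, 0.50, 0.95), meanlog = meanlog, sdlog = sdlog)
```

The implied mean recovers 0.25, while the prior median, `exp(meanlog)`, is lower than the mean because the lognormal distribution is right-skewed.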
We have some basic HMC-based validity statistics we can check.
pnbd_consol_lambmn_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_pnbd_consol_lambmn-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_consol_lambmn-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_consol_lambmn-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_pnbd_consol_lambmn-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
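The effective sample size check above is closely tied to autocorrelation: the more strongly successive draws are correlated, the less independent information each draw carries. The sketch below illustrates this on a simulated AR(1) series using the textbook approximation that the ratio of effective to actual sample size is roughly `(1 - rho) / (1 + rho)`. This is an illustration under simplified assumptions and not the estimator Stan uses; all names and values are hypothetical.

```r
set.seed(1)

n_draws <- 2000
rho     <- 0.8   # assumed lag-1 autocorrelation of the simulated chain

# Simulate a strongly autocorrelated chain as an AR(1) process
ar1_chain <- as.numeric(arima.sim(model = list(ar = rho), n = n_draws))

# Estimated lag-1 autocorrelation (element 1 of $acf is lag 0, element 2 is lag 1)
lag1_acf <- acf(ar1_chain, lag.max = 1, plot = FALSE)$acf[2]

# AR(1) approximation to the ratio of effective to actual sample size
approx_neff_ratio <- (1 - lag1_acf) / (1 + lag1_acf)
approx_neff       <- n_draws * approx_neff_ratio
```

With a lag-1 autocorrelation near 0.8, only around a tenth of the draws count as effectively independent, which is why the autocorrelation plots and the \(N_{eff}\) ratios are usually read together.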
3.2 Visual Diagnostics of the Sample Validity
Now that we have a sample from the posterior distribution we need to create a few different visualisations of the diagnostics.
pnbd_consol_lambmn_stanfit$draws(inc_warmup = FALSE) %>%
mcmc_trace(pars = c("lambda_mn", "alpha", "lambda[1]", "lambda[2]", "mu[1]", "mu[2]")) +
expand_limits(y = 0) +
labs(
x = "Iteration",
y = "Value",
title = "Traceplot of Sample of Lambda and Mu Values"
) +
theme(axis.text.x = element_text(size = 10))
A common MCMC diagnostic is \(\hat{R}\), a measure of the ‘similarity’ of the chains: values close to 1 indicate that the chains agree with each other.
pnbd_consol_lambmn_stanfit %>%
rhat(pars = c("lambda_mn", "lambda", "mu")) %>%
mcmc_rhat() +
ggtitle("Plot of Parameter R-hat Values")
Related to this quantity is the concept of effective sample size, \(N_{eff}\), an estimate of the size of the sample from a statistical information point of view.
pnbd_consol_lambmn_stanfit %>%
neff_ratio(pars = c("lambda_mn", "lambda", "mu")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes")
Finally, we also want to look at autocorrelation in the chains for each parameter.
pnbd_consol_lambmn_stanfit$draws() %>%
mcmc_acf(pars = c("lambda_mn", "alpha", "lambda[1]", "lambda[2]", "mu[1]", "mu[2]")) +
ggtitle("Autocorrelation Plot of Sample Values")
3.3 Validate the Lambda-Mean Model
run_chunk <- function(sim_file, param_tbl)
run_pnbd_simulations_chunk(
sim_file, param_tbl,
start_dttm = as.POSIXct("2011-01-01"),
end_dttm = as.POSIXct("2011-12-10")
)
pnbd_consol_lambmn_valid_lst <- construct_model_validation_data(
btyd_stanfit = pnbd_consol_lambmn_stanfit,
btyd_fitdata_tbl = btyd_fitdata_tbl,
btyd_obs_stats_tbl = btyd_obs_stats_tbl,
precompute_dir = "precompute/pnbd_consol_lambmn",
precompute_key = "sims_pnbd_consol_lambmn",
run_chunk_func = run_chunk
)
pnbd_consol_lambmn_valid_lst %>% glimpse(max.level = 1)
## List of 4
## $ btyd_validation_tbl : tibble [8,684,000 × 5] (S3: tbl_df/tbl/data.frame)
## $ btyd_validsims_tbl : tibble [8,684,000 × 5] (S3: tbl_df/tbl/data.frame)
## $ valid_custcount_plot:List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
## $ valid_tnxcount_plot :List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
Having run the simulations, we now want to check the outputs against the observed data.
pnbd_consol_lambmn_valid_lst$valid_custcount_plot %>% print()
pnbd_consol_lambmn_valid_lst$valid_tnxcount_plot %>% print()
4 Fit the Fixed Prior BG/NBD Model
We now want to fit our BG/NBD model with fixed priors.
## functions {
## #include util_functions.stan
## }
##
## data {
## int<lower=1> n; // number of customers
##
## vector<lower=0>[n] t_x; // time to most recent purchase
## vector<lower=0>[n] T_cal; // total observation time
## vector<lower=0>[n] x; // number of purchases observed
##
## real<lower=0> lambda_mn; // prior mean for lambda
## real<lower=0> lambda_cv; // prior cv for lambda
##
## real<lower=0,upper=1> p_mn; // prior mean for p
## real<lower=0> p_k; // prior strength for p
## }
##
##
## transformed data {
## real<lower=0> r = 1 / (lambda_cv * lambda_cv);
## real<lower=0> alpha = 1 / (lambda_cv * lambda_cv * lambda_mn);
##
## real<lower=0> a = p_k * p_mn;
## real<lower=0> b = p_k * (1 - p_mn);
## }
##
##
## parameters {
## vector<lower=0>[n] lambda; // purchase rate
## vector<lower=0,upper=1>[n] p; // dropout probability
## }
##
##
## model {
## // setting priors
## lambda ~ gamma(r, alpha);
## p ~ beta (a, b);
##
## target += calculate_bgnbd_loglik(n, lambda, p, x, t_x, T_cal);
## }
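The `transformed data` block above converts the more interpretable prior settings (a mean and coefficient of variation for \(\lambda\), a mean and strength for \(p\)) into the standard shape parameters of the Gamma and Beta distributions. The sketch below reproduces those conversions in R and checks that they recover the intended means; the helper names are hypothetical, and the input values are the ones passed to this model in the data list.

```r
# Gamma prior from a mean and coefficient of variation:
# mean = shape / rate and cv = 1 / sqrt(shape)
gamma_shape_rate <- function(mn, cv) {
  list(shape = 1 / cv^2, rate = 1 / (cv^2 * mn))
}

# Beta prior from a mean and a 'strength' k = a + b
beta_shapes <- function(mn, k) {
  list(a = k * mn, b = k * (1 - mn))
}

lambda_prior <- gamma_shape_rate(mn = 0.25, cv = 1.00)
p_prior      <- beta_shapes(mn = 0.10, k = 2.00)

# Moments implied by the converted parameters
implied_lambda_mn <- lambda_prior$shape / lambda_prior$rate
implied_lambda_cv <- 1 / sqrt(lambda_prior$shape)
implied_p_mn      <- p_prior$a / (p_prior$a + p_prior$b)
```

A coefficient of variation of 1 gives a Gamma shape of 1, so the prior on \(\lambda\) is an exponential distribution with mean 0.25, and a strength of 2 gives a fairly diffuse Beta prior on \(p\) centred on 0.10.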
We now compile this model using CmdStanR.
bgnbd_consol_fixed_stanmodel <- cmdstan_model(
"stan_code/bgnbd_fixed.stan",
include_paths = stan_codedir,
pedantic = TRUE,
dir = stan_modeldir
)
4.1 Fit the Model
We then use this compiled model with our data to sample from the posterior distribution.
stan_modelname <- "bgnbd_consol_fixed"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- btyd_fitdata_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn = 0.25,
lambda_cv = 1.00,
p_mn = 0.10,
p_k = 2.00,
)
bgnbd_consol_fixed_stanfit <- bgnbd_consol_fixed_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4203,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix,
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 328.1 seconds.
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 331.9 seconds.
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 334.7 seconds.
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 finished in 343.1 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 334.4 seconds.
## Total execution time: 343.6 seconds.
bgnbd_consol_fixed_stanfit$summary()
## # A tibble: 8,685 × 10
## variable mean median sd mad q5 q95 rhat ess_b…¹
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -5.92e+4 -5.92e+4 80.7 81.4 -5.93e+4 -5.90e+4 1.00 526.
## 2 lambda[1] 3.34e-1 3.26e-1 0.108 0.107 1.76e-1 5.24e-1 1.00 1353.
## 3 lambda[2] 1.68e-1 1.38e-1 0.124 0.102 2.95e-2 4.18e-1 1.00 2121.
## 4 lambda[3] 1.16e-1 9.69e-2 0.0829 0.0712 2.03e-2 2.83e-1 1.00 2247.
## 5 lambda[4] 8.07e-2 7.13e-2 0.0459 0.0402 2.28e-2 1.68e-1 1.00 2374.
## 6 lambda[5] 1.68e-1 1.03e-1 0.206 0.107 6.63e-3 5.59e-1 1.00 1970.
## 7 lambda[6] 2.15e-1 1.73e-1 0.174 0.122 3.66e-2 5.76e-1 0.999 2042.
## 8 lambda[7] 1.50e-1 8.30e-2 0.192 0.0943 5.99e-3 5.32e-1 1.00 1659.
## 9 lambda[8] 1.82e-1 8.44e-2 0.233 0.108 4.41e-3 6.63e-1 1.00 1065.
## 10 lambda[9] 2.10e-1 1.81e-1 0.134 0.117 5.73e-2 4.59e-1 1.00 2589.
## # … with 8,675 more rows, 1 more variable: ess_tail <dbl>, and abbreviated
## # variable name ¹ess_bulk
We have some basic HMC-based validity statistics we can check.
bgnbd_consol_fixed_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_bgnbd_consol_fixed-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_bgnbd_consol_fixed-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_bgnbd_consol_fixed-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_bgnbd_consol_fixed-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
4.2 Visual Diagnostics of the Sample Validity
Now that we have a sample from the posterior distribution we need to create a few different visualisations of the diagnostics.
sample_params <- c(
"lambda[1]", "lambda[2]", "lambda[3]", "lambda[4]", "lambda[5]", "lambda[6]",
"p[1]", "p[2]", "p[3]", "p[4]", "p[5]", "p[6]"
)
bgnbd_consol_fixed_stanfit$draws(inc_warmup = FALSE) %>%
mcmc_trace(pars = sample_params) +
expand_limits(y = 0) +
labs(
x = "Iteration",
y = "Value",
title = "Traceplot of Sample of Lambda and p Values"
) +
theme(axis.text.x = element_text(size = 10))
A common MCMC diagnostic is \(\hat{R}\), a measure of the ‘similarity’ of the chains: values close to 1 indicate that the chains agree with each other.
bgnbd_consol_fixed_stanfit %>%
rhat(pars = c("lambda", "p")) %>%
mcmc_rhat() +
ggtitle("Plot of Parameter R-hat Values")
Related to this quantity is the concept of effective sample size, \(N_{eff}\), an estimate of the size of the sample from a statistical information point of view.
bgnbd_consol_fixed_stanfit %>%
neff_ratio(pars = c("lambda", "p")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes")
Finally, we also want to look at autocorrelation in the chains for each parameter.
bgnbd_consol_fixed_stanfit$draws() %>%
mcmc_acf(pars = sample_params) +
ggtitle("Autocorrelation Plot of Sample Values")
4.3 Validate the Fixed Prior BG/NBD Model
run_chunk <- function(sim_file, param_tbl)
run_bgnbd_simulations_chunk(
sim_file, param_tbl,
start_dttm = as.POSIXct("2011-01-01"),
end_dttm = as.POSIXct("2011-12-10")
)
bgnbd_consol_fixed_valid_lst <- construct_model_validation_data(
btyd_stanfit = bgnbd_consol_fixed_stanfit,
btyd_fitdata_tbl = btyd_fitdata_tbl,
btyd_obs_stats_tbl = btyd_obs_stats_tbl,
precompute_dir = "precompute/bgnbd_consol_fixed",
precompute_key = "sims_bgnbd_consol_fixed",
run_chunk_func = run_chunk
)
bgnbd_consol_fixed_valid_lst %>% glimpse(max.level = 1)
## List of 4
## $ btyd_validation_tbl : tibble [8,684,000 × 4] (S3: tbl_df/tbl/data.frame)
## $ btyd_validsims_tbl : tibble [8,684,000 × 5] (S3: tbl_df/tbl/data.frame)
## $ valid_custcount_plot:List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
## $ valid_tnxcount_plot :List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
Having run the simulations, we now want to check the outputs against the observed data.
bgnbd_consol_fixed_valid_lst$valid_custcount_plot %>% print()
bgnbd_consol_fixed_valid_lst$valid_tnxcount_plot %>% print()
5 Fit the Hierarchical Means BG/NBD Model
We now want to fit our BG/NBD model with hierarchical priors for the means of both \(\lambda\) and \(p\).
## functions {
## #include util_functions.stan
## }
##
## data {
## int<lower=1> n; // number of customers
##
## vector<lower=0>[n] t_x; // time to most recent purchase
## vector<lower=0>[n] T_cal; // total observation time
## vector<lower=0>[n] x; // number of purchases observed
##
## real lambda_mn_p1; // lambda mean prior p1
## real<lower=0> lambda_mn_p2; // lambda mean prior p2
##
## real<lower=0> lambda_cv;
##
## real<lower=0> p_mn_mu; // p mean prior p1
## real<lower=0> p_mn_k; // p mean prior p2
##
## real<lower=0> p_k;
## }
##
##
## transformed data {
## real<lower=0> r = 1 / (lambda_cv * lambda_cv);
##
## real<lower=0> p_mn_a = p_mn_k * p_mn_mu;
## real<lower=0> p_mn_b = p_mn_k * (1 - p_mn_mu);
## }
##
## parameters {
## real<lower=0> lambda_mn;
## real<lower=0,upper=1> p_mn;
##
## vector<lower=0>[n] lambda; // purchase rate
## vector<lower=0,upper=1>[n] p; // dropout probability
## }
##
##
## transformed parameters {
## real<lower=0> alpha = 1 / (lambda_cv * lambda_cv * lambda_mn);
##
## real<lower=0> p_a = p_k * p_mn;
## real<lower=0> p_b = p_k * (1 - p_mn);
## }
##
##
## model {
## // set hyper-priors
## lambda_mn ~ lognormal(lambda_mn_p1, lambda_mn_p2);
##
## p_mn ~ beta (p_mn_a, p_mn_b);
##
## // setting priors
## lambda ~ gamma(r, alpha);
## p ~ beta (p_a, p_b);
##
## target += calculate_bgnbd_loglik(n, lambda, p, x, t_x, T_cal);
## }
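The transformed blocks reparameterise the priors in terms of a mean and a dispersion setting. As a short worked check, write \(c\) for `lambda_cv`, \(\mu_\lambda\) for `lambda_mn`, \(k\) for `p_k` and \(\mu_p\) for `p_mn` (shorthand introduced here for readability). Stan's `gamma` distribution takes a shape and a rate, so with \(r = 1 / c^2\) and \(\alpha = 1 / (c^2 \mu_\lambda)\) we get

\[
\mathbb{E}[\lambda] = \frac{r}{\alpha} = \mu_\lambda,
\qquad
\mathrm{sd}[\lambda] = \frac{\sqrt{r}}{\alpha} = c \, \mu_\lambda,
\qquad
\mathrm{CV}[\lambda] = \frac{1}{\sqrt{r}} = c .
\]

Similarly, with \(a = k \mu_p\) and \(b = k (1 - \mu_p)\) the Beta prior satisfies

\[
\mathbb{E}[p] = \frac{a}{a + b} = \mu_p,
\qquad
\mathrm{Var}[p] = \frac{\mu_p (1 - \mu_p)}{k + 1} .
\]

So `lambda_mn` and `p_mn` are the population means of \(\lambda\) and \(p\), while `lambda_cv` and `p_k` fix how tightly individual customers are spread around those means.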
We now compile this model using CmdStanR.
bgnbd_consol_hier_means_stanmodel <- cmdstan_model(
"stan_code/bgnbd_hier_means.stan",
include_paths = stan_codedir,
pedantic = TRUE,
dir = stan_modeldir
)
5.1 Fit the Model
We then use this compiled model with our data to produce a fit of the data.
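One detail in the data list that follows is the setting `lambda_mn_p1 = log(0.25) - 0.5 * (1.0)^2`. Our reading is that this offsets the log-scale location so that the lognormal hyper-prior on `lambda_mn` has a mean of 0.25, since a lognormal with parameters \(\mu\) and \(\sigma\) has mean \(\exp(\mu + \sigma^2 / 2)\). The sketch below checks this numerically; the variable names are illustrative and do not appear in the workshop code.

```r
# Parameters of the lognormal hyper-prior on lambda_mn.
target_mean <- 0.25
sigma       <- 1.0
mu          <- log(target_mean) - 0.5 * sigma^2

# Analytic mean of a lognormal(mu, sigma) distribution.
analytic_mean <- exp(mu + 0.5 * sigma^2)

# Monte Carlo check of the same quantity.
set.seed(4204)
sim_mean <- mean(rlnorm(1e6, meanlog = mu, sdlog = sigma))

c(analytic = analytic_mean, simulated = sim_mean)
```

Without the \(-\sigma^2 / 2\) offset, 0.25 would be the prior median instead, and the prior mean would be noticeably larger.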
stan_modelname <- "bgnbd_consol_hier_means"
stanfit_prefix <- str_c("fit_", stan_modelname)
stan_data_lst <- btyd_fitdata_tbl %>%
select(customer_id, x, t_x, T_cal) %>%
compose_data(
lambda_mn_p1 = log(0.25) - 0.5 * (1.0)^2,
lambda_mn_p2 = 1.0,
lambda_cv = 1.0,
p_mn_mu = 0.1,
p_mn_k = 100,
p_k = 10
)
bgnbd_consol_hier_means_stanfit <- bgnbd_consol_hier_means_stanmodel$sample(
data = stan_data_lst,
chains = 4,
iter_warmup = 500,
iter_sampling = 500,
seed = 4204,
save_warmup = TRUE,
output_dir = stan_modeldir,
output_basename = stanfit_prefix
)
## Running MCMC with 4 chains, at most 8 in parallel...
##
## Chain 1 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 3 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 4 Iteration: 1 / 1000 [ 0%] (Warmup)
## Chain 2 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 3 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 1 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 4 Iteration: 100 / 1000 [ 10%] (Warmup)
## Chain 2 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 3 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 1 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 4 Iteration: 200 / 1000 [ 20%] (Warmup)
## Chain 2 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 4 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 2 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 1 Iteration: 300 / 1000 [ 30%] (Warmup)
## Chain 3 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 4 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 2 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 2 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 1 Iteration: 400 / 1000 [ 40%] (Warmup)
## Chain 3 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 3 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 4 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 4 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 3 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 1 Iteration: 500 / 1000 [ 50%] (Warmup)
## Chain 1 Iteration: 501 / 1000 [ 50%] (Sampling)
## Chain 2 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 4 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 3 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 1 Iteration: 600 / 1000 [ 60%] (Sampling)
## Chain 2 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 4 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 3 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 1 Iteration: 700 / 1000 [ 70%] (Sampling)
## Chain 2 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 3 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 1 Iteration: 800 / 1000 [ 80%] (Sampling)
## Chain 2 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 2 finished in 228.9 seconds.
## Chain 4 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 3 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 3 finished in 231.6 seconds.
## Chain 1 Iteration: 900 / 1000 [ 90%] (Sampling)
## Chain 4 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 4 finished in 240.7 seconds.
## Chain 1 Iteration: 1000 / 1000 [100%] (Sampling)
## Chain 1 finished in 249.5 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 237.7 seconds.
## Total execution time: 249.9 seconds.
bgnbd_consol_hier_means_stanfit$summary()
## # A tibble: 8,690 × 10
## variable mean median sd mad q5 q95 rhat ess_b…¹
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 lp__ -5.40e+4 -5.40e+4 1.23e+2 1.27e+2 -5.42e+4 -5.38e+4 1.02 97.6
## 2 lambda_mn 1.25e-1 1.25e-1 2.76e-3 2.81e-3 1.21e-1 1.30e-1 1.01 395.
## 3 p_mn 1.35e-1 1.35e-1 5.95e-3 6.20e-3 1.26e-1 1.45e-1 1.03 89.1
## 4 lambda[1] 3.05e-1 2.93e-1 9.84e-2 9.77e-2 1.64e-1 4.83e-1 1.01 2120.
## 5 lambda[2] 1.24e-1 1.00e-1 9.26e-2 7.79e-2 2.04e-2 3.06e-1 1.00 1949.
## 6 lambda[3] 9.92e-2 8.19e-2 7.09e-2 6.06e-2 1.86e-2 2.38e-1 1.00 2373.
## 7 lambda[4] 7.24e-2 6.43e-2 4.20e-2 3.72e-2 1.85e-2 1.54e-1 1.00 1794.
## 8 lambda[5] 9.11e-2 5.92e-2 9.88e-2 6.14e-2 4.44e-3 2.88e-1 1.00 1893.
## 9 lambda[6] 1.46e-1 1.18e-1 1.09e-1 8.48e-2 2.41e-2 3.59e-1 1.00 2277.
## 10 lambda[7] 7.75e-2 5.10e-2 8.70e-2 5.17e-2 3.76e-3 2.38e-1 1.00 2227.
## # … with 8,680 more rows, 1 more variable: ess_tail <dbl>, and abbreviated
## # variable name ¹ess_bulk
We have some basic HMC-based validity statistics we can check.
bgnbd_consol_hier_means_stanfit$cmdstan_diagnose()
## Processing csv files: /home/rstudio/workshop/stan_models/fit_bgnbd_consol_hier_means-1.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_bgnbd_consol_hier_means-2.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_bgnbd_consol_hier_means-3.csvWarning: non-fatal error reading adaptation data
## , /home/rstudio/workshop/stan_models/fit_bgnbd_consol_hier_means-4.csvWarning: non-fatal error reading adaptation data
##
##
## Checking sampler transitions treedepth.
## Treedepth satisfactory for all transitions.
##
## Checking sampler transitions for divergences.
## No divergent transitions found.
##
## Checking E-BFMI - sampler transitions HMC potential energy.
## E-BFMI satisfactory.
##
## Effective sample size satisfactory.
##
## Split R-hat values satisfactory all parameters.
##
## Processing complete, no problems detected.
5.2 Visual Diagnostics of the Sample Validity
Now that we have a sample from the posterior distribution we need to create a few different visualisations of the diagnostics.
sample_params <- c(
"lambda_mn", "p_mn",
"lambda[1]", "lambda[2]", "lambda[3]", "lambda[4]", "lambda[5]",
"p[1]", "p[2]", "p[3]", "p[4]", "p[5]"
)
bgnbd_consol_hier_means_stanfit$draws(inc_warmup = FALSE) %>%
mcmc_trace(pars = sample_params) +
expand_limits(y = 0) +
labs(
x = "Iteration",
y = "Value",
title = "Traceplot of Sample of Lambda and p Values"
) +
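theme(axis.text.x = element_text(size = 10))
A common MCMC diagnostic is \(\hat{R}\), which compares the variance between chains with the variance within each chain. Values close to 1 suggest that the chains have mixed and are sampling from the same distribution.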
bgnbd_consol_hier_means_stanfit %>%
rhat(pars = c("lambda_mn", "p_mn", "lambda", "p")) %>%
mcmc_rhat() +
ggtitle("Plot of Parameter R-hat Values")
Related to this quantity is the concept of effective sample size, \(N_{eff}\), an estimate of how many independent draws would carry the same information as the autocorrelated MCMC sample.
bgnbd_consol_hier_means_stanfit %>%
neff_ratio(pars = c("lambda_mn", "p_mn", "lambda", "p")) %>%
mcmc_neff() +
ggtitle("Plot of Parameter Effective Sample Sizes")
Finally, we also want to look at autocorrelation in the chains for each parameter.
bgnbd_consol_hier_means_stanfit$draws() %>%
mcmc_acf(pars = sample_params) +
ggtitle("Autocorrelation Plot of Sample Values")
5.3 Validate the Hierarchical Means BG/NBD Model
run_chunk <- function(sim_file, param_tbl)
run_bgnbd_simulations_chunk(
sim_file, param_tbl,
start_dttm = as.POSIXct("2011-01-01"),
end_dttm = as.POSIXct("2011-12-10")
)
bgnbd_consol_hier_means_valid_lst <- construct_model_validation_data(
btyd_stanfit = bgnbd_consol_hier_means_stanfit,
btyd_fitdata_tbl = btyd_fitdata_tbl,
btyd_obs_stats_tbl = btyd_obs_stats_tbl,
precompute_dir = "precompute/bgnbd_consol_hier_means",
precompute_key = "sims_bgnbd_consol_hier_means",
run_chunk_func = run_chunk
)
bgnbd_consol_hier_means_valid_lst %>% glimpse(max.level = 1)
## List of 4
## $ btyd_validation_tbl : tibble [8,684,000 × 4] (S3: tbl_df/tbl/data.frame)
## $ btyd_validsims_tbl : tibble [8,684,000 × 5] (S3: tbl_df/tbl/data.frame)
## $ valid_custcount_plot:List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
## $ valid_tnxcount_plot :List of 9
## ..- attr(*, "class")= chr [1:2] "gg" "ggplot"
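The helpers `construct_model_validation_data()` and `run_bgnbd_simulations_chunk()` are defined elsewhere, so their internals are not shown here. As a rough sketch of the kind of forward simulation a BG/NBD validation relies on, the hypothetical function below draws the number of purchases for one customer over a window. It assumes that purchase times are exponential with rate \(\lambda\), that the customer is active at the start of the window, that dropout happens with probability \(p\) immediately after each purchase, and that time is measured in weeks. The parameter values are illustrative only, loosely based on the posterior means of `lambda_mn` and `p_mn` in the summary above.

```r
# Simulate the number of purchases one customer makes in a window of length
# `t_window`, assuming the customer is active at the start of the window.
simulate_bgnbd_count <- function(lambda, p, t_window) {
  t_now <- 0
  n_tnx <- 0L

  repeat {
    # Exponential waiting time to the next purchase.
    t_now <- t_now + rexp(1, rate = lambda)
    if (t_now > t_window) break

    n_tnx <- n_tnx + 1L

    # After each purchase the customer drops out with probability p.
    if (runif(1) < p) break
  }

  n_tnx
}

set.seed(42)

# 2011-01-01 to 2011-12-10 spans 343 days, i.e. 49 weeks.
sim_counts <- replicate(
  2000,
  simulate_bgnbd_count(lambda = 0.125, p = 0.135, t_window = 49)
)

quantile(sim_counts, probs = c(0.1, 0.5, 0.9))
```

Repeating this for every customer and every posterior draw gives a simulated distribution of customer and transaction counts, which can then be compared with the observed values.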
Having run the simulations, we now want to check the outputs against the observed data.
bgnbd_consol_hier_means_valid_lst$valid_custcount_plot %>% print()
bgnbd_consol_hier_means_valid_lst$valid_tnxcount_plot %>% print()
6 R Environment
options(width = 120L)
sessioninfo::session_info()
## ─ Session info ───────────────────────────────────────────────────────────────────────────────────────────────────────
## setting value
## version R version 4.2.1 (2022-06-23)
## os Ubuntu 20.04.5 LTS
## system x86_64, linux-gnu
## ui RStudio
## language (EN)
## collate en_US.UTF-8
## ctype en_US.UTF-8
## tz Etc/UTC
## date 2022-11-22
## rstudio 2022.07.2+576 Spotted Wakerobin (server)
## pandoc 2.19.2 @ /usr/lib/rstudio-server/bin/quarto/bin/tools/ (via rmarkdown)
##
## ─ Packages ───────────────────────────────────────────────────────────────────────────────────────────────────────────
## package * version date (UTC) lib source
## abind 1.4-5 2016-07-21 [1] RSPM (R 4.2.0)
## arrayhelpers 1.1-0 2020-02-04 [1] RSPM (R 4.2.0)
## assertthat 0.2.1 2019-03-21 [1] RSPM (R 4.2.0)
## backports 1.4.1 2021-12-13 [1] RSPM (R 4.2.0)
## base64enc 0.1-3 2015-07-28 [1] RSPM (R 4.2.0)
## bayesplot * 1.9.0 2022-03-10 [1] RSPM (R 4.2.0)
## bit 4.0.4 2020-08-04 [1] RSPM (R 4.2.0)
## bit64 4.0.5 2020-08-30 [1] RSPM (R 4.2.0)
## bookdown 0.29 2022-09-12 [1] RSPM (R 4.2.0)
## boot 1.3-28 2021-05-03 [2] CRAN (R 4.2.1)
## bridgesampling 1.1-2 2021-04-16 [1] RSPM (R 4.2.0)
## brms * 2.18.0 2022-11-01 [1] Github (paul-buerkner/brms@28f778d)
## Brobdingnag 1.2-9 2022-10-19 [1] RSPM (R 4.2.0)
## broom 1.0.1 2022-08-29 [1] RSPM (R 4.2.0)
## bslib 0.4.0 2022-07-16 [1] RSPM (R 4.2.0)
## cachem 1.0.6 2021-08-19 [1] RSPM (R 4.2.0)
## callr 3.7.2 2022-08-22 [1] RSPM (R 4.2.0)
## cellranger 1.1.0 2016-07-27 [1] RSPM (R 4.2.0)
## checkmate 2.1.0 2022-04-21 [1] RSPM (R 4.2.0)
## cli 3.4.1 2022-09-23 [1] RSPM (R 4.2.0)
## cmdstanr * 0.5.3 2022-11-01 [1] Github (stan-dev/cmdstanr@22b391e)
## coda 0.19-4 2020-09-30 [1] RSPM (R 4.2.0)
## codetools 0.2-18 2020-11-04 [2] CRAN (R 4.2.1)
## colorspace 2.0-3 2022-02-21 [1] RSPM (R 4.2.0)
## colourpicker 1.1.1 2021-10-04 [1] RSPM (R 4.2.0)
## conflicted * 1.1.0 2021-11-26 [1] RSPM (R 4.2.0)
## cowplot * 1.1.1 2020-12-30 [1] RSPM (R 4.2.0)
## crayon 1.5.2 2022-09-29 [1] RSPM (R 4.2.0)
## crosstalk 1.2.0 2021-11-04 [1] RSPM (R 4.2.0)
## curl 4.3.3 2022-10-06 [1] RSPM (R 4.2.0)
## data.table 1.14.4 2022-10-17 [1] RSPM (R 4.2.0)
## DBI 1.1.3 2022-06-18 [1] RSPM (R 4.2.0)
## dbplyr 2.2.1 2022-06-27 [1] RSPM (R 4.2.0)
## digest 0.6.30 2022-10-18 [1] RSPM (R 4.2.0)
## directlabels * 2021.1.13 2021-01-16 [1] RSPM (R 4.2.0)
## distributional 0.3.1 2022-09-02 [1] RSPM (R 4.2.0)
## dplyr * 1.0.10 2022-09-01 [1] RSPM (R 4.2.0)
## DT 0.26 2022-10-19 [1] RSPM (R 4.2.0)
## dygraphs 1.1.1.6 2018-07-11 [1] RSPM (R 4.2.0)
## ellipsis 0.3.2 2021-04-29 [1] RSPM (R 4.2.0)
## evaluate 0.17 2022-10-07 [1] RSPM (R 4.2.0)
## fansi 1.0.3 2022-03-24 [1] RSPM (R 4.2.0)
## farver 2.1.1 2022-07-06 [1] RSPM (R 4.2.0)
## fastmap 1.1.0 2021-01-25 [1] RSPM (R 4.2.0)
## forcats * 0.5.2 2022-08-19 [1] RSPM (R 4.2.0)
## fs * 1.5.2 2021-12-08 [1] RSPM (R 4.2.0)
## furrr * 0.3.1 2022-08-15 [1] RSPM (R 4.2.0)
## future * 1.28.0 2022-09-02 [1] RSPM (R 4.2.0)
## gamm4 0.2-6 2020-04-03 [1] RSPM (R 4.2.0)
## gargle 1.2.1 2022-09-08 [1] RSPM (R 4.2.0)
## generics 0.1.3 2022-07-05 [1] RSPM (R 4.2.0)
## ggdist 3.2.0 2022-07-19 [1] RSPM (R 4.2.0)
## ggplot2 * 3.3.6 2022-05-03 [1] RSPM (R 4.2.0)
## ggridges 0.5.4 2022-09-26 [1] RSPM (R 4.2.0)
## globals 0.16.1 2022-08-28 [1] RSPM (R 4.2.0)
## glue * 1.6.2 2022-02-24 [1] RSPM (R 4.2.0)
## googledrive 2.0.0 2021-07-08 [1] RSPM (R 4.2.0)
## googlesheets4 1.0.1 2022-08-13 [1] RSPM (R 4.2.0)
## gridExtra 2.3 2017-09-09 [1] RSPM (R 4.2.0)
## gtable 0.3.1 2022-09-01 [1] RSPM (R 4.2.0)
## gtools 3.9.3 2022-07-11 [1] RSPM (R 4.2.0)
## haven 2.5.1 2022-08-22 [1] RSPM (R 4.2.0)
## highr 0.9 2021-04-16 [1] RSPM (R 4.2.0)
## hms 1.1.2 2022-08-19 [1] RSPM (R 4.2.0)
## htmltools 0.5.3 2022-07-18 [1] RSPM (R 4.2.0)
## htmlwidgets 1.5.4 2021-09-08 [1] RSPM (R 4.2.0)
## httpuv 1.6.6 2022-09-08 [1] RSPM (R 4.2.0)
## httr 1.4.4 2022-08-17 [1] RSPM (R 4.2.0)
## igraph 1.3.5 2022-09-22 [1] RSPM (R 4.2.0)
## inline 0.3.19 2021-05-31 [1] RSPM (R 4.2.0)
## jquerylib 0.1.4 2021-04-26 [1] RSPM (R 4.2.0)
## jsonlite 1.8.3 2022-10-21 [1] RSPM (R 4.2.0)
## knitr 1.40 2022-08-24 [1] RSPM (R 4.2.0)
## labeling 0.4.2 2020-10-20 [1] RSPM (R 4.2.0)
## later 1.3.0 2021-08-18 [1] RSPM (R 4.2.0)
## lattice 0.20-45 2021-09-22 [2] CRAN (R 4.2.1)
## lifecycle 1.0.3 2022-10-07 [1] RSPM (R 4.2.0)
## listenv 0.8.0 2019-12-05 [1] RSPM (R 4.2.0)
## lme4 1.1-30 2022-07-08 [1] RSPM (R 4.2.0)
## loo 2.5.1 2022-03-24 [1] RSPM (R 4.2.0)
## lubridate 1.8.0 2021-10-07 [1] RSPM (R 4.2.0)
## magrittr * 2.0.3 2022-03-30 [1] RSPM (R 4.2.0)
## markdown 1.2 2022-10-19 [1] RSPM (R 4.2.0)
## MASS 7.3-57 2022-04-22 [2] CRAN (R 4.2.1)
## Matrix 1.5-1 2022-09-13 [1] RSPM (R 4.2.0)
## matrixStats 0.62.0 2022-04-19 [1] RSPM (R 4.2.0)
## memoise 2.0.1 2021-11-26 [1] RSPM (R 4.2.0)
## mgcv 1.8-40 2022-03-29 [2] CRAN (R 4.2.1)
## mime 0.12 2021-09-28 [1] RSPM (R 4.2.0)
## miniUI 0.1.1.1 2018-05-18 [1] RSPM (R 4.2.0)
## minqa 1.2.5 2022-10-19 [1] RSPM (R 4.2.0)
## modelr 0.1.9 2022-08-19 [1] RSPM (R 4.2.0)
## munsell 0.5.0 2018-06-12 [1] RSPM (R 4.2.0)
## mvtnorm 1.1-3 2021-10-08 [1] RSPM (R 4.2.0)
## nlme 3.1-157 2022-03-25 [2] CRAN (R 4.2.1)
## nloptr 2.0.3 2022-05-26 [1] RSPM (R 4.2.0)
## parallelly 1.32.1 2022-07-21 [1] RSPM (R 4.2.0)
## pillar 1.8.1 2022-08-19 [1] RSPM (R 4.2.0)
## pkgbuild 1.3.1 2021-12-20 [1] RSPM (R 4.2.0)
## pkgconfig 2.0.3 2019-09-22 [1] RSPM (R 4.2.0)
## plyr 1.8.7 2022-03-24 [1] RSPM (R 4.2.0)
## posterior * 1.3.1 2022-09-06 [1] RSPM (R 4.2.0)
## prettyunits 1.1.1 2020-01-24 [1] RSPM (R 4.2.0)
## processx 3.8.0 2022-10-26 [1] RSPM (R 4.2.0)
## projpred 2.2.1 2022-09-20 [1] RSPM (R 4.2.0)
## promises 1.2.0.1 2021-02-11 [1] RSPM (R 4.2.0)
## ps 1.7.2 2022-10-26 [1] RSPM (R 4.2.0)
## purrr * 0.3.5 2022-10-06 [1] RSPM (R 4.2.0)
## quadprog 1.5-8 2019-11-20 [1] RSPM (R 4.2.0)
## R6 2.5.1 2021-08-19 [1] RSPM (R 4.2.0)
## Rcpp * 1.0.9 2022-07-08 [1] RSPM (R 4.2.0)
## RcppParallel 5.1.5 2022-01-05 [1] RSPM (R 4.2.0)
## readr * 2.1.3 2022-10-01 [1] RSPM (R 4.2.0)
## readxl 1.4.1 2022-08-17 [1] RSPM (R 4.2.0)
## reprex 2.0.2 2022-08-17 [1] RSPM (R 4.2.0)
## reshape2 1.4.4 2020-04-09 [1] RSPM (R 4.2.0)
## rlang * 1.0.6 2022-09-24 [1] RSPM (R 4.2.0)
## rmarkdown 2.17 2022-10-07 [1] RSPM (R 4.2.0)
## rmdformats 1.0.4 2022-05-17 [1] RSPM (R 4.2.0)
## rstan 2.26.13 2022-11-01 [1] local
## rstantools 2.2.0 2022-04-08 [1] RSPM (R 4.2.0)
## rstudioapi 0.14 2022-08-22 [1] RSPM (R 4.2.0)
## rvest 1.0.3 2022-08-19 [1] RSPM (R 4.2.0)
## sass 0.4.2 2022-07-16 [1] RSPM (R 4.2.0)
## scales * 1.2.1 2022-08-20 [1] RSPM (R 4.2.0)
## sessioninfo 1.2.2 2021-12-06 [1] RSPM (R 4.2.0)
## shiny 1.7.3 2022-10-25 [1] RSPM (R 4.2.0)
## shinyjs 2.1.0 2021-12-23 [1] RSPM (R 4.2.0)
## shinystan 2.6.0 2022-03-03 [1] RSPM (R 4.2.0)
## shinythemes 1.2.0 2021-01-25 [1] RSPM (R 4.2.0)
## StanHeaders 2.26.13 2022-11-01 [1] local
## stringi 1.7.8 2022-07-11 [1] RSPM (R 4.2.0)
## stringr * 1.4.1 2022-08-20 [1] RSPM (R 4.2.0)
## svUnit 1.0.6 2021-04-19 [1] RSPM (R 4.2.0)
## tensorA 0.36.2 2020-11-19 [1] RSPM (R 4.2.0)
## threejs 0.3.3 2020-01-21 [1] RSPM (R 4.2.0)
## tibble * 3.1.8 2022-07-22 [1] RSPM (R 4.2.0)
## tidybayes * 3.0.2.9000 2022-11-01 [1] Github (mjskay/tidybayes@1efbdef)
## tidyr * 1.2.1 2022-09-08 [1] RSPM (R 4.2.0)
## tidyselect 1.2.0 2022-10-10 [1] RSPM (R 4.2.0)
## tidyverse * 1.3.2 2022-07-18 [1] RSPM (R 4.2.0)
## tzdb 0.3.0 2022-03-28 [1] RSPM (R 4.2.0)
## utf8 1.2.2 2021-07-24 [1] RSPM (R 4.2.0)
## V8 4.2.1 2022-08-07 [1] RSPM (R 4.2.0)
## vctrs 0.5.0 2022-10-22 [1] RSPM (R 4.2.0)
## vroom 1.6.0 2022-09-30 [1] RSPM (R 4.2.0)
## withr 2.5.0 2022-03-03 [1] RSPM (R 4.2.0)
## xfun 0.34 2022-10-18 [1] RSPM (R 4.2.0)
## xml2 1.3.3 2021-11-30 [1] RSPM (R 4.2.0)
## xtable 1.8-4 2019-04-21 [1] RSPM (R 4.2.0)
## xts 0.12.2 2022-10-16 [1] RSPM (R 4.2.0)
## yaml 2.3.6 2022-10-18 [1] RSPM (R 4.2.0)
## zoo 1.8-11 2022-09-17 [1] RSPM (R 4.2.0)
##
## [1] /usr/local/lib/R/site-library
## [2] /usr/local/lib/R/library
##
## ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
options(width = 80L)